{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "MNIST_on_TPU.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "TPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ECija61OXJT4"
      },
      "source": [
        "## Getting Started with PyTorch-Ignite on Cloud TPUs\n",
        "\n",
        "This notebook is based on [\"Getting Started with PyTorch on Cloud TPUs\"](https://colab.research.google.com/github/pytorch/xla/blob/master/contrib/colab/getting-started.ipynb#scrollTo=RKLajLqUni6H) and will show you how to:\n",
        "\n",
        "- Install PyTorch/XLA on Colab, which lets you use PyTorch with TPUs.\n",
        "- Train a basic model on MNIST with PyTorch-Ignite.\n",
        "\n",
        "PyTorch/XLA is a package that lets PyTorch connect to Cloud TPUs and use TPU cores as devices. Colab provides a free Cloud TPU system (a remote CPU host + four TPU chips with two cores each) and installing PyTorch/XLA only takes a couple minutes.\n",
        "\n",
        "\n",
        "<h3>  &nbsp;&nbsp;Use Colab Cloud TPU&nbsp;&nbsp; <a href=\"https://cloud.google.com/tpu/\"><img valign=\"middle\" src=\"https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png\" width=\"50\"></a></h3>\n",
        "\n",
        "* On the main menu, click Runtime and select **Change runtime type**. Set \"TPU\" as the hardware accelerator.\n",
        "* The cell below makes sure you have access to a TPU on Colab.\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0nIBhUekU593"
      },
      "source": [
        "import os\n",
        "assert os.environ.get('COLAB_TPU_ADDR'), 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "v-_pVHuPW_zn"
      },
      "source": [
        "## Installing PyTorch/XLA\n",
        "\n",
        "Run the following cell (or copy it into your own notebook!) to install PyTorch, Torchvision, and PyTorch/XLA. It will take a couple minutes to run.\n",
        "\n",
        "The PyTorch/XLA package lets PyTorch connect to Cloud TPUs. (It's named PyTorch/XLA, not PyTorch/TPU, because XLA is the name of the TPU compiler.) In particular, PyTorch/XLA makes TPU cores available as PyTorch devices. This lets PyTorch create and manipulate tensors on TPUs."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4mnaxX3eXNGd"
      },
      "source": [
        "VERSION = !curl -s https://api.github.com/repos/pytorch/xla/releases/latest | grep -Po '\"tag_name\": \"v\\K.*?(?=\")'\n",
        "VERSION = VERSION[0].rstrip('0').rstrip('.') # remove trailing zero\n",
        "!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-{VERSION}-cp37-cp37m-linux_x86_64.whl"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iKFb5vdtgODx"
      },
      "source": [
        "## Required Dependencies\n",
        "\n",
        "We assume that `torch` is already installed. PyTorch-Ignite (`ignite`) can be installed with `pip`:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6bsFP_2vcTlP"
      },
      "source": [
        "!pip install pytorch-ignite"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CHYlKKHUaoOm"
      },
      "source": [
        "## Train a basic model on MNIST with PyTorch-Ignite"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0ki0DNMKaN9S"
      },
      "source": [
        "The PyTorch/XLA API is simple: PyTorch uses Cloud TPUs just as it uses CPU or CUDA devices, so only minor changes are needed to train models with PyTorch and Ignite. We will use the code from this example: https://github.com/pytorch/ignite/blob/master/examples/mnist/mnist_with_tensorboard_on_tpu.py\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Gn4XhksofbPd"
      },
      "source": [
        "### Import libraries"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EFAJ65OmZPjK"
      },
      "source": [
        "import torch\n",
        "print(\"PyTorch version:\", torch.__version__)\n",
        "\n",
        "# imports the torch_xla package\n",
        "import torch_xla\n",
        "import torch_xla.core.xla_model as xm\n",
        "print(\"PyTorch xla version:\", torch_xla.__version__)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "KW4s6SiOVEwF"
      },
      "source": [
        "# Import PyTorch, Torchvision and Tensorboard\n",
        "from torch.utils.data import DataLoader\n",
        "from torch import nn\n",
        "import torch.nn.functional as F\n",
        "from torch.optim import SGD\n",
        "from torchvision.datasets import MNIST\n",
        "from torchvision.transforms import Compose, ToTensor, Normalize\n",
        "\n",
        "from torch.utils.tensorboard import SummaryWriter\n",
        "\n",
        "# Import PyTorch-Ignite\n",
        "from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator\n",
        "from ignite.metrics import Accuracy, Loss, RunningAverage\n",
        "from ignite.handlers import ProgressBar"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ksx7UGGZhHGD"
      },
      "source": [
        "### Data processing"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QnT6QuTkhC2d"
      },
      "source": [
        "# Dataloaders\n",
        "def get_data_loaders(train_batch_size, val_batch_size):\n",
        "    data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])\n",
        "\n",
        "    train_loader = DataLoader(\n",
        "        MNIST(download=True, root=\".\", transform=data_transform, train=True), batch_size=train_batch_size, shuffle=True\n",
        "    )\n",
        "\n",
        "    val_loader = DataLoader(\n",
        "        MNIST(download=False, root=\".\", transform=data_transform, train=False), batch_size=val_batch_size, shuffle=False\n",
        "    )\n",
        "    return train_loader, val_loader"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EG9C4-yShJ5o"
      },
      "source": [
        "train_batch_size = 64\n",
        "val_batch_size = train_batch_size * 2\n",
        "\n",
        "train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hUSL88A5giaO"
      },
      "source": [
        "### Create a model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Biefe9RaZb7_"
      },
      "source": [
        "# Set up a basic CNN\n",
        "class Net(nn.Module):\n",
        "    def __init__(self):\n",
        "        super(Net, self).__init__()\n",
        "        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n",
        "        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n",
        "        self.conv2_drop = nn.Dropout2d()\n",
        "        self.fc1 = nn.Linear(320, 50)\n",
        "        self.fc2 = nn.Linear(50, 10)\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = F.relu(F.max_pool2d(self.conv1(x), 2))\n",
        "        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n",
        "        x = x.view(-1, 320)\n",
        "        x = F.relu(self.fc1(x))\n",
        "        x = F.dropout(x, training=self.training)\n",
        "        x = self.fc2(x)\n",
        "        return F.log_softmax(x, dim=-1)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2eN7bb8yhQbK"
      },
      "source": [
        "model = Net()\n",
        "device = xm.xla_device()\n",
        "model = model.to(device)  # Move model before creating optimizer"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "egKdsX-mhZ1h"
      },
      "source": [
        "### Optimizer and trainers"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JvCwVzkDeWVG"
      },
      "source": [
        "optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)\n",
        "\n",
        "# Create trainer and evaluator\n",
        "trainer = create_supervised_trainer(\n",
        "    model, \n",
        "    optimizer, \n",
        "    F.nll_loss, \n",
        "    device=device, \n",
        "    output_transform=lambda x, y, y_pred, loss: [loss.item(), ]\n",
        ")\n",
        "\n",
        "evaluator = create_supervised_evaluator(\n",
        "    model, metrics={\"accuracy\": Accuracy(), \"nll\": Loss(F.nll_loss)}, device=device\n",
        ")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "clTM9TrThd7L"
      },
      "source": [
        "### Handlers"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "X8hiRkm2cr8m"
      },
      "source": [
        "# Set up event handlers\n",
        "log_interval = 10\n",
        "\n",
        "log_dir = \"/tmp/tb_logs\"\n",
        "\n",
        "# TensorBoard writer\n",
        "writer = SummaryWriter(log_dir=log_dir)\n",
        "\n",
        "tracker = xm.RateTracker()\n",
        "\n",
        "# Add RateTracker as an output of the training step\n",
        "@trainer.on(Events.ITERATION_COMPLETED)\n",
        "def add_rate_tracker(engine):\n",
        "    # The batch is an (inputs, targets) pair, so count the samples in the inputs\n",
        "    tracker.add(len(engine.state.batch[0]))\n",
        "    engine.state.output.append(tracker.global_rate())\n",
        "\n",
        "# Set up output values of the training step as EMA metrics\n",
        "RunningAverage(output_transform=lambda x: x[0]).attach(trainer, \"batch_loss\")\n",
        "RunningAverage(output_transform=lambda x: x[1]).attach(trainer, \"global_rate\")\n",
        "\n",
        "# Let's log the EMA metrics every `log_interval` iterations\n",
        "@trainer.on(Events.ITERATION_COMPLETED(every=log_interval))\n",
        "def log_training_loss(engine):\n",
        "    writer.add_scalar(\"training/batch_loss\", engine.state.metrics[\"batch_loss\"], engine.state.iteration)\n",
        "    writer.add_scalar(\"training/global_rate\", engine.state.metrics[\"global_rate\"], engine.state.iteration)\n",
        "\n",
        "# Set up a progress bar (tqdm) and display the batch loss and global rate metrics in the bar\n",
        "pbar = ProgressBar()\n",
        "pbar.attach(trainer, [\"batch_loss\", \"global_rate\"])\n",
        "\n",
        "# Let's compute training metrics: average accuracy and loss\n",
        "@trainer.on(Events.EPOCH_COMPLETED)\n",
        "def log_training_results(engine):\n",
        "    evaluator.run(train_loader)\n",
        "    metrics = evaluator.state.metrics\n",
        "    avg_accuracy = metrics[\"accuracy\"]\n",
        "    avg_nll = metrics[\"nll\"]\n",
        "    pbar.log_message(\n",
        "        f\"Training Results - Epoch: {engine.state.epoch}  Avg accuracy: {avg_accuracy:.2f} Avg loss: {avg_nll:.2f}\"\n",
        "    )\n",
        "    writer.add_scalar(\"training/avg_loss\", avg_nll, engine.state.epoch)\n",
        "    writer.add_scalar(\"training/avg_accuracy\", avg_accuracy, engine.state.epoch)\n",
        "\n",
        "# Let's compute validation metrics: average accuracy and loss\n",
        "@trainer.on(Events.EPOCH_COMPLETED)\n",
        "def log_validation_results(engine):\n",
        "    evaluator.run(val_loader)\n",
        "    metrics = evaluator.state.metrics\n",
        "    avg_accuracy = metrics[\"accuracy\"]\n",
        "    avg_nll = metrics[\"nll\"]\n",
        "    print(\n",
        "        f\"Validation Results - Epoch: {engine.state.epoch}  Avg accuracy: {avg_accuracy:.2f} Avg loss: {avg_nll:.2f}\"\n",
        "    )\n",
        "    writer.add_scalar(\"validation/avg_loss\", avg_nll, engine.state.epoch)\n",
        "    writer.add_scalar(\"validation/avg_accuracy\", avg_accuracy, engine.state.epoch)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "IxFocNa7hnhy"
      },
      "source": [
        "# Display in Firefox may not work properly. Use Chrome.\n",
        "%load_ext tensorboard\n",
        "\n",
        "%tensorboard --logdir=\"/tmp/tb_logs\""
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qbC-Th_JcsS1"
      },
      "source": [
        "# kick everything off\n",
        "!rm -rf /tmp/tb_logs/*\n",
        "\n",
        "trainer.run(train_loader, max_epochs=10)"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}
